notebooks/Unit 9 - Model Optimization/torch.ipynb (318 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PyTorch Quantization\n", "\n", "PyTorch supports INT8 quantization. Compared with typical FP32 models, INT8 allows up to a 4x reduction in model size and in memory bandwidth requirements\n", "while still achieving comparable accuracy for many applications. This notebook demonstrates how to quantize a model from FP32 to INT8 using PyTorch's quantization tooling. We will train a simple CNN on MNIST, quantize it, and compare the accuracy and size of the quantized model with those of the original FP32 model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup PyTorch\n", "\n", "First, let's install PyTorch and torchvision and then import the required modules.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install torch torchvision" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torchvision import datasets, transforms\n", "import torch.quantization\n", "import pathlib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dynamic Quantization\n", "\n", "In dynamic quantization, weights are quantized ahead of time, while activations are read and stored in floating point and are quantized on the fly only for compute."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load MNIST Dataset\n", "\n", "First, we load the MNIST dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "transform = transforms.Compose([\n", "    transforms.ToTensor(),\n", "    transforms.Normalize((0.1307,), (0.3081,))\n", "])\n", "\n", "train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)\n", "test_dataset = datasets.MNIST('./data', train=False, transform=transform)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train the Model\n", "\n", "Next, we define a simple CNN model and then train it on the MNIST dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Net(nn.Module):\n", "    def __init__(self):\n", "        super(Net, self).__init__()\n", "        self.conv1 = nn.Conv2d(in_channels=1, out_channels=12, kernel_size=3)\n", "        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n", "        # 28x28 input -> 26x26 after the 3x3 conv -> 13x13 after 2x2 pooling\n", "        self.fc = nn.Linear(12 * 13 * 13, 10)\n", "\n", "    def forward(self, x):\n", "        x = x.reshape(-1, 1, 28, 28)\n", "        x = F.relu(self.conv1(x))\n", "        x = self.pool(x)\n", "        x = x.reshape(x.size(0), -1)\n", "        x = self.fc(x)\n", "        output = F.log_softmax(x, dim=1)\n", "        return output\n", "\n", "\n", "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)\n", "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32)\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "epochs = 1\n", "\n", "model = Net().to(device)\n", "optimizer = optim.Adam(model.parameters())\n", "\n", "model.train()\n", "\n", "for epoch in range(1, epochs+1):\n", "    for batch_idx, (data, target) in enumerate(train_loader):\n", "        data, target = data.to(device), target.to(device)\n", "        optimizer.zero_grad()\n", "        output = model(data)\n", "        loss = F.nll_loss(output, target)\n", "        loss.backward()\n", "        optimizer.step()\n", "        print('Train Epoch: {} [{}/{} 
({:.0f}%)]\\tLoss: {:.6f}'.format(\n", "            epoch, batch_idx * len(data), len(train_loader.dataset),\n", "            100. * batch_idx / len(train_loader), loss.item()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Quantize Model\n", "\n", "After training, we can quantize the model using the `torch.quantization.quantize_dynamic` function from PyTorch. Here only the `nn.Linear` layers are quantized." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Move the model to the CPU before applying dynamic quantization\n", "model.to('cpu')\n", "quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Check Model Size\n", "\n", "The quantized model should be noticeably smaller than the original, because the weights of the `Linear` layer are now stored as INT8." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "models_dir = pathlib.Path(\"./models/\")\n", "models_dir.mkdir(exist_ok=True, parents=True)\n", "torch.save(model.state_dict(), \"./models/original_model.p\")\n", "torch.save(quantized_model.state_dict(), \"./models/quantized_model.p\")\n", "\n", "%ls -lh models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Check Accuracy\n", "\n", "The quantized model should have accuracy comparable to that of the original model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def test(model, device, data_loader, quantized=False):\n", "    model.to(device)\n", "    model.eval()\n", "    test_loss = 0\n", "    correct = 0\n", "    with torch.no_grad():\n", "        for data, target in data_loader:\n", "            data, target = data.to(device), target.to(device)\n", "            output = model(data)\n", "            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss\n", "            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability\n", "            correct += pred.eq(target.view_as(pred)).sum().item()\n", "\n", "    test_loss /= len(data_loader.dataset)\n", "\n", "    return 100. 
* correct / len(data_loader.dataset)\n", "\n", "original_acc = test(model, \"cpu\", test_loader)\n", "quantized_acc = test(quantized_model, \"cpu\", test_loader)\n", "\n", "print('Original model accuracy: {:.0f}%'.format(original_acc))\n", "print('Quantized model accuracy: {:.0f}%'.format(quantized_acc))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Post-training Static Quantization\n", "\n", "In post-training static quantization, both weights and activations are quantized, and a calibration step is required after training. Here we quantize the model using the FX graph mode functions in `torch.ao.quantization.quantize_fx` and compare the accuracy and size of the quantized model with those of the original FP32 model.\n", "\n", "To apply post-training static quantization, first define a model or load a pre-trained one, then create a quantization configuration mapping using the default for the QNNPACK engine. Set the model to evaluation mode and create a sample input tensor. Next, prepare the model for quantization using the `quantize_fx.prepare_fx()` function, which applies the quantization configuration mapping and inserts observers that record activation ranges. The prepared model is then run on the sample input to calibrate those observers. 
Finally, convert the calibrated model to a quantized model by calling `quantize_fx.convert_fx()` and save it to disk.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torch.ao.quantization import get_default_qconfig_mapping\n", "import torch.ao.quantization.quantize_fx as quantize_fx\n", "import copy\n", "\n", "loaded_model = Net()\n", "loaded_model.load_state_dict(torch.load(\"./models/original_model.p\"))\n", "model_to_quantize = copy.deepcopy(loaded_model)\n", "\n", "qconfig_mapping = get_default_qconfig_mapping(\"qnnpack\")\n", "model_to_quantize.eval()\n", "\n", "# Use a single test image as the example and calibration input\n", "input_fp32 = next(iter(test_loader))[0][0:1]\n", "input_fp32 = input_fp32.to('cpu')\n", "\n", "# prepare_fx takes example_inputs as a tuple of positional arguments\n", "model_fp32_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, (input_fp32,))\n", "# Run the prepared model once so the observers record activation ranges\n", "model_fp32_prepared(input_fp32)\n", "model_int8 = quantize_fx.convert_fx(model_fp32_prepared)\n", "\n", "torch.save(model_int8.state_dict(), \"./models/post_quantized_model.p\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Check Model Size\n", "\n", "Again, the quantized model should be noticeably smaller than the original model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%ls -lh models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Check Accuracy\n", "\n", "Again, the accuracy of the quantized model should not differ much from that of the original model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "quantized_acc = test(model_int8, \"cpu\", test_loader, quantized=True)\n", "print('Post quantized model accuracy: {:.0f}%'.format(quantized_acc))" ] } ], "metadata": { "kernelspec": { "display_name": "model_optimization", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": 
"text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 2 }